5555
20228
Wat is de gemakkelijkste manier om de tensor van vorm (batch_grootte, hoogte, breedte) gevuld met n waarden om te zetten in tensor van vorm (batch_size, n, hoogte, breedte)?
Ik heb de onderstaande oplossing gemaakt, maar het lijkt erop dat er een gemakkelijkere en snellere manier is om dit te doen
def batch_tensor_to_onehot (tnsr, klassen):
tnsr = tnsr.unsqueeze (1)
res = []
voor cls in bereik (klassen):
res.append ((tnsr == cls) .long ())
return torch.cat (res, dim = 1) 
U kunt torch.nn.functional.one_hot gebruiken.
Voor jouw geval:
a = torch.nn.functional.one_hot (tnsr, num_classes = klassen)
out = a.permute (0, 3, 1, 2)
​
Je zou ook Tensor.scatter_ kunnen gebruiken, die .permute vermijdt, maar aantoonbaar moeilijker te begrijpen is dan de eenvoudige methode die wordt voorgesteld door @Alpha.
def batch_tensor_to_onehot (tnsr, klassen):
result = torch.zeros (tnsr.shape [0], classes, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device)
result.scatter_ (1, tnsr.unsqueeze (1), 1)
resultaat retourneren
Benchmarking resultaten
Ik was nieuwsgierig en besloot de drie benaderingen te benchmarken. Ik ontdekte dat er geen significant relatief verschil lijkt te zijn tussen de voorgestelde methoden met betrekking tot batchgrootte, breedte of hoogte. Vooral het aantal klassen was de onderscheidende factor. Zoals bij elke benchmark kan de kilometerstand natuurlijk variëren.
De benchmarks werden verzameld met behulp van willekeurige indices en met batchgrootte, hoogte, breedte = 100. Elk experiment werd 20 keer herhaald, waarbij het gemiddelde werd gerapporteerd. Het experiment num_classes = 100 wordt eenmaal uitgevoerd voordat profilering voor opwarming plaatsvindt.
De CPU-resultaten laten zien dat de oorspronkelijke methode waarschijnlijk het beste was voor num_classes kleiner dan ongeveer 30, terwijl voor GPU de scatter_-benadering de snelste lijkt te zijn.
Tests uitgevoerd op Ubuntu 18.04, NVIDIA 2060 Super, i7-9700K
De code die wordt gebruikt voor benchmarking wordt hieronder gegeven:
fakkel importeren
van tqdm import tqdm
import tijd
importeer matplotlib.pyplot als plt
def batch_tensor_to_onehot_slavka (tnsr, klassen):
tnsr = tnsr.unsqueeze (1)
res = []
voor cls in bereik (klassen):
res.append ((tnsr == cls) .long ())
return torch.cat (res, dim = 1)
def batch_tensor_to_onehot_alpha (tnsr, klassen):
result = torch.nn.functional.one_hot (tnsr, num_classes = klassen)
retourneer resultaat. permute (0, 3, 1, 2)
def batch_tensor_to_onehot_jodag (tnsr, klassen):
result = torch.zeros (tnsr.shape [0], classes, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device)
result.scatter_ (1, tnsr.unsqueeze (1), 1)
resultaat retourneren
def main ():
aantal_klassen = [2, 10, 25, 50, 100]
hoogte = 100
breedte = 100
bs = [100] * 20
voor d in ['cpu', 'cuda']:
times_slavka = []
times_alpha = []
times_jodag = []
warmup = waar
voor c in tqdm ([aantal_klassen [-1]] + aantal_klassen, ncols = 0):
tslavka = 0
talpha = 0
tjodag = 0
voor b in bs:
tnsr = torch.randint (c, (b, hoogte, breedte)). to (device = d)
t0 = tijd.tijd ()
y = batch_tensor_to_onehot_slavka (tnsr, c)
torch.cuda.synchronize ()
tslavka + = tijd.tijd () - t0
zo niet opwarmen:
times_slavka.append (tslavka / len (bs))
voor b in bs:
tnsr = torch.randint (c, (b, hoogte, breedte)). to (device = d)
t0 = tijd.tijd ()
y = batch_tensor_to_onehot_alpha (tnsr, c)
torch.cuda.synchronize ()
talpha + = tijd.tijd () - t0
zo niet opwarmen:
times_alpha.append (talpha / len (bs))
voor b in bs:
tnsr = torch.randint (c, (b, hoogte, breedte)). to (device = d)
t0 = tijd.tijd ()
y = batch_tensor_to_onehot_jodag (tnsr, c)
torch.cuda.synchronize ()
tjodag + = tijd.tijd () - t0
zo niet opwarmen:
times_jodag.append (tjodag / len (bs))
warmup = False
fig = plt.figure ()
ax = fig.subplots ()
ax.plot (num_classes, times_slavka, label = 'Slavka-cat')
ax.plot (num_classes, times_alpha, label = 'Alpha-one_hot')
ax.plot (aantal_klassen, tijden_jodag, label = 'jodag-scatter_')
ax.set_xlabel ('aantal_klassen')
ax.set_ylabel ('tijd (en)')
ax.set_title (f '{d} benchmark')
ax.legend ()
plt.savefig (f '{d} .png')
plt.show ()
if __name__ == "__main__":
hoofd()
​
Uw antwoord
StackExchange.ifUsing ("editor", function () {
StackExchange.using ("externalEditor", function () {
StackExchange.using ("snippets", function () {
StackExchange.snippets.init ();
​
​
}, "code-snippets");
StackExchange.ready (function () {
var channelOptions = {
tags: "" .split (""),
id: "1"
​
initTagRenderer ("". split (""), "" .split (""), channelOptions);
StackExchange.using ("externalEditor", function () {
// Moet de editor na fragmenten activeren, als fragmenten zijn ingeschakeld
if (StackExchange.settings.snippets.snippetsEnabled) {
StackExchange.using ("snippets", function () {
createEditor ();
​
​
anders {
createEditor ();
​
​
functie createEditor () {
StackExchange.prepareEditor ({
useStacksEditor: false,
heartbeatType: 'antwoord',
autoActivateHeartbeat: false,
convertImagesToLinks: waar,
noModals: waar,
showLowRepImageUploadWarning: true,
ReputationToPostImages: 10,
bindNavPrevention: true,
postfix: "",
imageUploader: {
brandingHtml: "Aangedreven door \ u003ca href = \" https: //imgur.com/ \ "\ u003e \ u003csvg class = \" svg-icon \ "width = \" 50 \ "hoogte = \" 18 \ "viewBox = \ "0 0 50 18 \" fill = \ "none \" xmlns = \ "http: //www.w3.org/2000/svg \" \ u003e \ u003cpath d = \ "M46.1709 9.17788C46.1709 8.26454 46.2665 7.94324 47.1084 7.58816C47.4091 7.46349 47.7169 7.36433 48.0099 7.26993C48.9099 6.97997 49.672 6.73443 49.672 5.93063C49.672 5.22043 48.9832 4.61182 48.1414 4.61182C47.4335 4.61182 46.7256 4.91628 46.0943 5.000 43.1481 6.59048V11.9512C43.1481 13.2535 43.6264 13.8962 44.6595 13.8962C45.6924 13.8962 46.1709 13.253546.1709 11.9512V9.17788Z \ "/ \ u003e \ u003cpath d = \" M32.492 10.1419C32.492 12.6954 34.1182 14.0484 37.0451 14.0484C39.9723 14.0484 41.5985 12.6954 41.5985 10.1419V6.59049C41.5985 4.6954 41.5985 10.1419V6.59049C41.1985 38.5948 5.28821 38.5948 6.59049V9.60062C38.5948 10.8521 38.2696 11.5455 37.0451 11.5455C35.8209 11.5455 35.4954 10.8521 35.4954 9.60062V6.59049C35.4954 5.28821 35.0173 4.66232 34.00343 4.66232C32 " fill-rule = \ "evenodd \" clip-rule = \ "evenodd \" d = \ "M25.6622 17.6335C27.8049 17.6335 29.3739 16.9402 30.2537 15.6379C30.8468 14.7755 30.9615 13.5579 30.9615 11.9512V6.59049C30.9615 5.28821 30.4 29.4502 4.66231C28.9913 4.66231 28.4555 4.94978 28.1109 5.50789C27.499 4.86533 26.7335 4.56087 25.7005 4.56087C23.1369 4.56087 21.0134 6.57349 21.0134 9.27932C21.0134 11.9852 23.003 13.913 25.3754 2860.16.16.16.16.16.16.16.16.16.26.18.26 C28. 1256 12.8854 28,1301 12,9342 28,1301 12.983C28.1301 14,4373 27,2502 15,2321 25,777 15.2321C24.8349 15,2321 24,1352 14,9821 23,5661 14.7787C23.176 14,6393 22,8472 14,5218 22,5437 14.5218C21.7977 14,5218 21,2429 15,0123 21,2429 15.6887C21.2429 16,7375 22,9072 17,6335 25,6622 17.6335ZM24.1317 9,27932 C24.1317 7.94324 24.9928 7.09766 26.1024 7.09766C27.2119 7.09766 28.0918 7.94324 28.0918 9.27932C28.0918 10.6321 27.2311 11.5116 26.1024 11.5116C24.9737 11.5116 24.1317 10.6491 24.1317 9.27932Ze \ "/ \ u0016.803 d \" / \ u0016.8045 8045 13.2535 17.2637 13.8962 18.2965 13.8962C19.3298 13.8962 19.8079 13.2535 19.8079 11.9512V8.12928C19.8079 5.82936 18.4879 4.62866 16.4027 4.62866C15.1594 4.62866 14.279 4.98375 13.3609 5.88013C12.653 5.0515466 11.68.653 5.0515466 58314 4.9328 7.10506 4.66232 6.51203 4.66232C5.47873 4.66232 5.00066 5.28821 5.00066 6.59049V11.9512C5.00066 13.2535 5.47873 13.8962 6.51203 13.8962C7.54479 13.8962 8.0232 13 .2535 8.0232 11.9512V8.90741C8.0232 7.58817 8.44431 6.91179 9.53458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512C10.893 13.2535 11.3711 13.8962 12.4044 13.8962C13.4375 13.8962 13.8962C13.4375 13.8962 13.8962C13.4375 13.8962 13.8962C13.4375 C16.4027 6.91179 16.8045 7.58817 16.8045 8.94108V11.9512Z \ "/ \ u003e \ u003cpath d = \" M3.31675 6.59049C3.31675 5.28821 2.83866 4.66232 1.82471 4.66232C0.791758 4.66232 0.313354 5.28821 0.313354 6.59049.39.39.3 ... 1.82471 13.8962C2.85798 13.8962 3.31675 13.2535 3.31675 11.9512V6.59049Z \ "/ \ u003e \ u003cpath d = \" M1.87209 0.400291C0.843612 0.400291 0 1.1159 0 1.98861C0 2.87869 0.822846 3.57676 1.87867 3.57676 1.87209 3.57676 C3.7234 1.1159 2.90056 0.400291 1.87209 0.400291Z \ "fill = \" # 1BB76E \ "/ \ u003e \ u003c / svg \ u003e \ u003c / a \ u003e",
contentPolicyHtml: "Gebruikersbijdragen gelicentieerd onder \ u003ca href = \" https: //stackoverflow.com/help/licensing \ "\ u003ecc by-sa \ u003c / a \ u003e \ u003ca href = \" https://stackoverflow.com / legal / content-policy \ "\ u003e (contentbeleid) \ u003c / a \ u003e",
allowUrls: waar
​
onDemand: waar,
discardSelector: ".discard-answer"
, onmiddellijkShowMarkdownHelp: true, enableTables: true, enableSnippets: true
​
​
​
Bedankt voor het bijdragen aan een antwoord op Stack Overflow!
Zorg ervoor dat u de vraag beantwoordt. Geef details en deel uw onderzoek!
Maar vermijd ...
Om hulp vragen, opheldering vragen of reageren op andere antwoorden.
Uitspraken doen op basis van meningen; ondersteun ze met referenties of persoonlijke ervaring.
Bekijk onze tips voor het schrijven van goede antwoorden voor meer informatie.
Concept opgeslagen
Concept verwijderd
Meld u aan of log in
StackExchange.ready (function () {
StackExchange.helpers.onClickDraftSave ('# login-link');
​
Meld u aan met Google
Meld u aan met Facebook
Meld u aan met e-mail en wachtwoord
Verzenden
Post als gast
Naam
E-mail
Vereist, maar nooit getoond
StackExchange.ready (
functie () {
StackExchange.openid.initPostLogin ('. New-post-login', 'https% 3a% 2f% 2fstackoverflow.com% 2fquestions% 2f62245173% 2fpytorch-transform-tensor-to-one-hot% 23new-answer', 'question_page' );
​
​
Post als gast
Naam
E-mail
Vereist, maar nooit getoond
Plaats uw antwoord
Gooi weg
Door op "Plaats uw antwoord" te klikken, gaat u akkoord met onze servicevoorwaarden, privacybeleid en cookiebeleid
Niet het antwoord waar je naar zoekt? Blader door andere vragen met de tag python pytorch tensor one-hot-encoding of stel uw eigen vraag.